"""This package includes all the modules related to data loading and preprocessing

 To add a custom dataset class called 'dummy', add a file called 'dummy_dataset.py' and define in it a subclass 'DummyDataset' that inherits from BaseDataset.
 You need to implement four functions:
    -- <__init__>:                      initialize the class; first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of the dataset.
    -- <__getitem__>:                   return a single data point (requested by the data loader).
    -- <modify_commandline_options>:    (optional) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py" and return its dataset class.

    The module must define a class whose name equals ``dataset_name`` with all
    underscores removed plus the suffix ``dataset``, compared case-insensitively
    (e.g. ``"my_set"`` matches ``MySetDataset``). That class has to be a subclass
    of BaseDataset. If several names match, the last one found in the module
    namespace wins.

    Parameters:
        dataset_name (str) -- the dataset mode, e.g. "dummy"

    Returns:
        The matching BaseDataset subclass (the class itself, not an instance).

    Raises:
        ImportError -- propagated from importlib if the module cannot be imported.
        NotImplementedError -- if the module has no matching BaseDataset subclass.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # issubclass() raises TypeError for non-class objects (functions,
        # instances) that happen to share the target name, so check the type first.
        if (name.lower() == target_dataset_name.lower()
                and isinstance(cls, type)
                and issubclass(cls, BaseDataset)):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Resolve the dataset class registered under `dataset_name` and hand back
    its <modify_commandline_options> static method (not called)."""
    return find_dataset_using_name(dataset_name).modify_commandline_options


def create_dataset(opt):
    """Build the dataset wrapper described by `opt` and return it ready for iteration.

    Thin front door over CustomDatasetDataLoader; this is the entry point that
    'train.py'/'test.py' use to obtain data.

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    return CustomDatasetDataLoader(opt).load_data()


class CustomDatasetDataLoader:
    """Wrapper class of Dataset class that performs multi-threaded data loading.

    After construction the wrapper is in one of two layouts:
      * mode 'ori'   -- one loader over the whole dataset
                        (``dataset_ori`` / ``dataloader_ori``);
      * mode 'train' -- the dataset was randomly split into training and
                        validation subsets (``dataset_train`` / ``dataloader_train``
                        and ``dataset_val`` / ``dataloader_val``); call
                        ``set_mode('val')`` to iterate over the validation part.

    ``len()`` and iteration always refer to the currently selected mode.
    """

    def __init__(self, opt):
        """Initialize this class.

        Step 1: create a dataset instance given the name [dataset_mode].
        Step 2: create a multi-threaded data loader; if ``opt`` has a ``val``
                attribute, split the dataset into train/validation loaders.

        Parameters:
            opt -- option object; stored as ``self.opt`` and read for dataset_mode,
                   batch_size, serial_batches, num_threads, max_dataset_size and
                   (when splitting) val_ratio.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset_ori = dataset_class(opt)
        print("dataset [%s] was created" % type(self.dataset_ori).__name__)
        # NOTE(review): the split is enabled by the mere presence of an ``opt.val``
        # attribute, not by its value -- confirm a False ``opt.val`` is not meant
        # to disable it.
        if hasattr(opt, 'val') and len(self.dataset_ori) > 0:
            self.set_train_val()
        else:
            # An empty dataset also lands here, so ``mode`` is always set and
            # len()/iteration yield nothing instead of raising AttributeError.
            self.dataloader_ori = self._make_loader(self.dataset_ori)
            self.set_mode('ori')

    def _make_loader(self, dataset):
        """Build a DataLoader over `dataset` from the batch/shuffle/worker options."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.opt.batch_size,
            # RandomSampler rejects empty datasets, so only shuffle when there is data.
            shuffle=not self.opt.serial_batches and len(dataset) > 0,
            num_workers=int(self.opt.num_threads))

    def load_data(self):
        """Return this wrapper itself; it is directly iterable."""
        return self

    def __len__(self):
        """Return the number of data in the current mode's dataset, capped at max_dataset_size."""
        return min(len(getattr(self, 'dataset_' + self.mode)), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches from the current mode's loader.

        Stops once ``i * batch_size`` reaches ``max_dataset_size``, where ``i`` is
        the batch index, so the last yielded batch may overshoot the cap.
        """
        for i, data in enumerate(getattr(self, 'dataloader_' + self.mode)):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data

    def set_mode(self, mode):
        """Select which dataset/loader pair ('ori', 'train' or 'val') len()/iteration use."""
        self.mode = mode

    def set_train_val(self):
        """Randomly split ``dataset_ori`` into training and validation subsets.

        ``opt.val_ratio`` is the fraction of samples assigned to the validation
        subset (rounded down); the remainder is used for training. Builds one
        loader per subset and switches to mode 'train'.
        """
        n_total = len(self.dataset_ori)
        n_val = int(n_total * self.opt.val_ratio)
        # Validation receives val_ratio of the samples; training gets the rest.
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(
            self.dataset_ori, [n_total - n_val, n_val])
        self.dataloader_train = self._make_loader(self.dataset_train)
        self.dataloader_val = self._make_loader(self.dataset_val)
        self.set_mode('train')